{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "52c47f3c",
   "metadata": {},
   "source": [
    "# 一、标签在文件夹上的数据加载\n",
    "\n",
    ">数据集文件夹  \n",
    ">>标签(文件夹）  \n",
    ">>>图片.jpg  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec3bdd36",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Dataset, TensorDataset, DataLoader #数据集打开加载工具\n",
    "\n",
    "from torchvision import transforms,utils,datasets # 图片增强工具\n",
    "\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ad2fa82",
   "metadata": {},
   "source": [
    "## 1.定义transform增强数据\n",
    "\n",
    "```PYTHON\n",
    "from torchvision import transforms\n",
    "```\n",
    "\n",
    "**1.定义transform参数**\n",
    "```python\n",
    "trans变量 = transforms.Compose([\n",
    "    # 缩放图片Image，保持长宽比不变，最短边为__像素\n",
    "    transforms.Resize(缩放像素), \n",
    "    # 从图片中间切出__*__图片\n",
    "    transforms.CenterCrop(像素), \n",
    "    # 将图片Image转成Tensor格式(通道,高,宽),并归一化至[0,1]\n",
    "    transforms.ToTensor(),        \n",
    "])\n",
    "```\n",
    "\n",
    "**2.用transform变量变换处理图片**\n",
    "```python\n",
    "图片img = trans变量(图片img)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "190422a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_transform = transforms.Compose([\n",
    "    transforms.Resize(32),  # 缩放图片Image，保持长宽比不变，最短边为32像素\n",
    "    transforms.CenterCrop(32),  #从图片中间切出32*32图片\n",
    "    transforms.ToTensor(),  # 将图片Image转成Tensor格式(通道,高,宽),并归一化至[0,1]\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7f095d9",
   "metadata": {},
   "source": [
    "## 2.使用ImageFolder进行数据集导入(生成数据集dataset)\n",
    "\n",
    "```python\n",
    "from torchvision import datasets\n",
    "```\n",
    "\n",
    "**1.不使用transform增强导入生成数据集**\n",
    "```python\n",
    "数据集变量 = datasets.ImageFolder(\n",
    "    root = './训练数据根文件夹路径',  \n",
    "    transform = None          \n",
    ")\n",
    "```\n",
    "\n",
    "**2.使用transform增强导入生成数据集**\n",
    "```python\n",
    "数据集变量 = datasets.ImageFolder(\n",
    "    root = './训练数据根文件夹路径',  \n",
    "    transform = trans变量          \n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b25469ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transform增强后数据集\n",
    "\n",
    "dataset = datasets.ImageFolder(\n",
    "    root = './horses-or-human',   # 训练数据根文件夹路径\n",
    "    transform = data_transform    # transform增强调用\n",
    ")\n",
    "\n",
    "# 查看数据集第一个样例\n",
    "\n",
    "img, label = dataset[0]\n",
    "print(label)\n",
    "print(img.shape)\n",
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72e6e891",
   "metadata": {},
   "source": [
    "# 3.加载数据集（loader）\n",
    "\n",
    "```python\n",
    "from torch.utils.data import DataLoader \n",
    "```\n",
    "\n",
    "```python\n",
    "loader变量 = DataLoader(\n",
    "    数据集, \n",
    "    batch_size = 一次训练图片数, \n",
    "    shuffle = True/False是否随机选择数据\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93d88656",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_loader = DataLoader(dataset, batch_size=256, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fab84ea",
   "metadata": {},
   "source": [
    "---\n",
    "---\n",
    "---\n",
    "# 二、标签在图片命名上的数据加载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afcff669",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Dataset, TensorDataset, DataLoader #数据集打开加载工具\n",
    "\n",
    "from torchvision import transforms,utils,datasets # 图片增强工具\n",
    "\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd14390f",
   "metadata": {},
   "source": [
    "## 1.定义transform增强数据\n",
    "\n",
    "```PYTHON\n",
    "from torchvision import transforms\n",
    "```\n",
    "\n",
    "**1.定义transform参数**\n",
    "```python\n",
    "trans变量 = transforms.Compose([\n",
    "    # 缩放图片Image，保持长宽比不变，最短边为__像素\n",
    "    transforms.Resize(缩放像素), \n",
    "    # 从图片中间切出__*__图片\n",
    "    transforms.CenterCrop(像素), \n",
    "    # 将图片Image转成Tensor格式(通道,高,宽),并归一化至[0,1]\n",
    "    transforms.ToTensor(),        \n",
    "])\n",
    "```\n",
    "\n",
    "**2.用transform变量变换处理图片**\n",
    "```python\n",
    "图片img = trans变量(图片img)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abd73b68",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_transform = transforms.Compose([\n",
    "    transforms.Resize(32),  # 缩放图片Image，保持长宽比不变，最短边为32像素\n",
    "    transforms.CenterCrop(32),  #从图片中间切出32*32图片\n",
    "    transforms.ToTensor(),  # 将图片Image转成Tensor格式(通道,高,宽),并归一化至[0,1]\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80f6c35e",
   "metadata": {},
   "source": [
    "## 2.利用Dataset类进行数据导入(重载Dataset)\n",
    "```python\n",
    "from torch.utils.data import Dataset, TensorDataset, DataLoader\n",
    "from PIL import Image\n",
    "```\n",
    "\n",
    "**1.重载Dataset类**\n",
    "```python\n",
    "class Dataset自定义类名(Dataset):\n",
    "```\n",
    "\n",
    "**2.初始化属性**\n",
    "```python\n",
    "def __init__(self, path_dir, transform=None):\n",
    "    self.path_dir = path_dir # 文件路径\n",
    "    self.transform = transform #对图形预处理\n",
    "    self.images = os.listdir(self.path_dir) # 把路径下的所有文件放入一个列表\n",
    "\n",
    "def __len__(self):\n",
    "    return len(self.images) # 返回数据集大小\n",
    "```\n",
    "\n",
    "**3.根据索引index返回图像及标签**\n",
    "```python\n",
    "def __getitem__(self, index):\n",
    "    # 获取图片\n",
    "    image_index = self.images[index] # 根据索引获取图像文件名\n",
    "    img_path = os.path.join(self.path_dir, image_index)\n",
    "    img = Image.open(img_path).convert('图像通道')  # 图像通道：RGB\n",
    "    if self.transform is not None:\n",
    "        img = self.transform(img)\n",
    "            \n",
    "    # 获取标签，并将标签用数字int编码\n",
    "    label = img_path.split('\\\\')[-1].split('.')[0]\n",
    "    ...需要改写...\n",
    "    if '图片标签' in label:\n",
    "        label = 0\n",
    "    elif '图片标签' in label:\n",
    "        label = 1\n",
    "    ...需要改写...\n",
    "        \n",
    "    return img,label\n",
    "```\n",
    "\n",
    "**5.调用重载类**\n",
    "```python\n",
    "dataset = Dataset自定义类名('文件夹路径',trans变量)\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d1af0c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDataset(Dataset):\n",
    "    #初始化属性\n",
    "    def __init__(self, path_dir, transform=None):\n",
    "        self.path_dir = path_dir # 文件路径\n",
    "        self.transform = transform #对图形预处理\n",
    "        self.images = os.listdir(self.path_dir) # 把路径下的所有文件放入一个列表\n",
    "        \n",
    "    #返回整个数据集大小\n",
    "    def __len__(self):\n",
    "        return len(self.images)\n",
    "    \n",
    "    #根据索引index、返回图像及标签\n",
    "    def __getitem__(self, index):\n",
    "        # 获取图片\n",
    "        image_index = self.images[index] # 根据索引获取图像文件名\n",
    "        img_path = os.path.join(self.path_dir, image_index)\n",
    "        img = Image.open(img_path).convert('RGB')   #读取图像\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "            \n",
    "        # 获取标签\n",
    "        label = img_path.split('\\\\')[-1].split('.')[0]\n",
    "        if 'horse' in label:\n",
    "            label = 1\n",
    "        elif 'human' in label:\n",
    "            label = 0 \n",
    "        \n",
    "        return img,label\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b37542fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = MyDataset('E:\\code\\jupyter\\py-torch-learning-notes-master\\实例2\\horses-or-human\\humans')\n",
    "\n",
    "img,label = dataset[5]\n",
    "print(label)\n",
    "img\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c187758b",
   "metadata": {},
   "source": [
    "# 加载数据集\n",
    "\n",
    "```python\n",
    "from torch.utils.data import DataLoader \n",
    "```\n",
    "\n",
    "```python\n",
    "loader变量 = DataLoader(\n",
    "    数据集, \n",
    "    batch_size = 一次训练图片数, \n",
    "    shuffle = True/False是否随机选择数据\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b346ee7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_loader = DataLoader(dataset, batch_size=256, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c54f7153",
   "metadata": {},
   "source": [
    "---\n",
    "---\n",
    "---\n",
    "# 三、创建自己的数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f75c7e23",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, TensorDataset ,DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03b4256e",
   "metadata": {},
   "source": [
    "## 1.将数据转为tensor格式\n",
    "`数据 = torch.tensor(mumpy数据)` "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50334207",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_images = torch.tensor(train_images)\n",
    "train_labels = torch.tensor(train_labels)\n",
    "\n",
    "test_images = torch.tensor(test_images)\n",
    "test_labels = torch.tensor(test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66c421d3",
   "metadata": {},
   "source": [
    "## 2.数据处理\n",
    "\n",
    ">图像数据处理：\n",
    ">>**图像数据列表维度shape：[图像数量,通道维数,图像长像素,图像宽像素]**  \n",
    ">>缺少通道维黑白图像处理:`图片样本data = Variable(torch.unsqueeze(图片样本data, dim=1), volatile=True).type(torch.FloatTensor)/255`  \n",
    ">>数据类型转换:`数据变量 = 数据变量.type(torch.FloatTensor)`\n",
    "\n",
    ">标签处理\n",
    ">>转换为one-hot编码:`标签labels = utils.to_categorical(标签labels)`  \n",
    ">>标签转换成long数据格式：`标签labels = 标签labels.long()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6586f41e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将标签转换成long格式\n",
    "train_labels = train_labels.long()\n",
    "test_labels = test_labels.long()\n",
    "\n",
    "# 图像数据调整增加维度 [图片数, 长, 宽]->[图片数, 通道数, 长, 宽], 将数据转为tensor的Float格式\n",
    "train_images = Variable(torch.unsqueeze(train_images, dim=1), volatile=True).type(torch.FloatTensor)/255\n",
    "test_images = Variable(torch.unsqueeze(test_images, dim=1), volatile=True).type(torch.FloatTensor)/255"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "756339bb",
   "metadata": {},
   "source": [
    "## 3.生成数据集\n",
    "`数据集 = TensorDataset(样本data, 标签labels)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54d5b702",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = TensorDataset(train_images, train_labels)\n",
    "test_dataset = TensorDataset(test_images, test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fa1336c",
   "metadata": {},
   "source": [
    "## 4.加载数据集\n",
    "`train_loader = DataLoader(train_dataset, batch_size=120)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e6c3c25",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)\n",
    "test_loader =DataLoader(test_dataset, batch_size=120)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
